/**
 * Builds the internal model from the training data and, when k-selection
 * via cross-validation is enabled, tunes the number of neighbors.
 *
 * @param aTrain the multi-label training dataset
 * @throws Exception if the superclass build or the cross-validation fails
 */
protected void buildInternal(MultiLabelInstances aTrain) throws Exception {
    super.buildInternal(aTrain);
    // Idiom fix: test the boolean directly instead of comparing with "== true".
    if (cvkSelection) {
        crossValidate();
    }
}
protected void crossValidate() throws Exception { try { // the performance for each different k // 不同的k可以造成不同的表现 double[] hammingLoss = new double[cvMaxK]; // 记录每个模型的Hamming Loss for (int i = 0; i < cvMaxK; i++) { hammingLoss[i] = 0; } Instances dataSet = train;//nothing Instance instance; // the hold out instance Instances neighbours; // the neighboring instances double[] origDistances, convertedDistances;//表示orig距离 for (int i = 0; i < dataSet.numInstances(); i++) { if (getDebug() && (i % 50 == 0)) { debug("Cross validating " + i + "/" + dataSet.numInstances() + "\r"); } instance = dataSet.instance(i); neighbours = lnn.kNearestNeighbours(instance, cvMaxK); //instance's k neighbours origDistances = lnn.getDistances(); // gathering the true labels for the instance boolean[] trueLabels = new boolean[numLabels]; for (int counter = 0; counter < numLabels; counter++) { int classIdx = labelIndices[counter]; String classValue = instance.attribute(classIdx).value( (int) instance.value(classIdx)); trueLabels[counter] = classValue.equals("1"); } // calculate the performance metric for each different k for (int j = cvMaxK; j > 0; j--) { convertedDistances = new double[origDistances.length]; System.arraycopy(origDistances, 0, convertedDistances, 0, origDistances.length); double[] confidences = this.getConfidences(neighbours, convertedDistances); boolean[] bipartition = null; switch (extension) { case NONE: // BRknn /**那么选择默认的方法得到最终分类结果 bipartition = results.getBipartition(); */ MultiLabelOutput results; results = new MultiLabelOutput(confidences, 0.5); bipartition = results.getBipartition(); break; case EXTA: // BRknn-a /**那么选择如果最终预测结果不输出任何结果时输出最大置信度的类标的 方法得到最终分类结果 bipartition =labelsFromConfidences2(confidences); */ bipartition = labelsFromConfidences2(confidences); break; case EXTB: // BRknn-b /*选择输出固定类标数目的方法得到预测结果 */ bipartition = labelsFromConfidences3(confidences); break; } double symmetricDifference = 0; // |Y xor Z| for (int labelIndex = 0; labelIndex < numLabels; 
labelIndex++) { //统计每个样例的预测结果和实际结果的差别 boolean actual = trueLabels[labelIndex]; boolean predicted = bipartition[labelIndex]; if (predicted != actual) { symmetricDifference++; } } hammingLoss[j - 1] += (symmetricDifference / numLabels); neighbours = new IBk().pruneToK(neighbours, convertedDistances, j - 1); } } // Display the results of the cross-validation if (getDebug()) { for (int i = cvMaxK; i > 0; i--) { debug("Hold-one-out performance of " + (i) + " neighbors "); debug("(Hamming Loss) = " + hammingLoss[i - 1] / dataSet.numInstances()); } } // Check through the performance stats and select the best // k value (or the lowest k if more than one best) double[] searchStats = hammingLoss; double bestPerformance = Double.NaN; int bestK = 1; for (int i = 0; i < cvMaxK; i++) { if (Double.isNaN(bestPerformance) || (bestPerformance > searchStats[i])) { bestPerformance = searchStats[i]; bestK = i + 1; } } numOfNeighbors = bestK; if (getDebug()) { System.err.println("Selected k = " + bestK); } } catch (Exception ex) { throw new Error("Couldn't optimize by cross-validation: " + ex.getMessage()); } }
You need to enable Javascript in your browser to edit pages.
help on how to format text
Introduction
Table of Contents
Here is the classifier's build method:
The parameter aTrain is the training dataset.
The buildInternal method is overridden from the parent class.
The crossValidate() method is shown below.
Note: the switch statement in the middle of the code selects among three different optimization strategies (the NONE/EXTA/EXTB extensions).